# File-level gate: skip_connection() and test_requires() are test helpers
# defined outside this file.
# NOTE(review): presumably these skip the whole file when the current
# connection type or the dplyr package is not available -- confirm.
skip_connection("dplyr")

test_requires("dplyr")

# Spark connection shared by every test in this file.
sc <- testthat_spark_connection()

# Spark-side tables for the built-in datasets. The tests below address the
# iris columns on the Spark side with "_" in place of "." (e.g. Sepal_Length).
iris_tbl <- testthat_tbl("iris")
mtcars_tbl <- testthat_tbl("mtcars")

# Used as skip_if(!has_predicates) by the where()/select_if() tests.
has_predicates <- tidyselect_data_has_predicates(mtcars_tbl)



# Small local fixtures; they must exist before testthat_tbl() is called with
# their names as strings.
df1 <- tibble(a = 1:3, b = letters[1:3])
df2 <- tibble(b = letters[1:3], c = letters[24:26])

df1_tbl <- testthat_tbl("df1")
df2_tbl <- testthat_tbl("df2")

# Single-column id tables used by the compute()/mutate()/transmute() tests.
sdf_5 <- copy_to(sc, data.frame(id = 1:5))
sdf_10 <- copy_to(sc, data.frame(id = 1:10))

# Mixed-type fixture: integer, character, POSIXct and integer columns.
# NOTE(review): not referenced by any test visible in this file -- it may be
# used elsewhere or be a leftover.
dplyr_across_test_cases_df <- tibble(
  x = seq(3),
  y = as.character(seq(3)),
  t = as.POSIXct(seq(3), origin = "1970-01-01"),
  z = seq(3) + 5L
)
dplyr_across_test_cases_tbl <- testthat_tbl("dplyr_across_test_cases_df")

# Expect that the dbplyr remote name of `x` equals `y`.
#
# x: a remote table whose name is read with dbplyr::remote_name().
# y: the expected name as a string.
#
# With dbplyr <= 2.3.4 the expected value is wrapped in ident() before the
# comparison; with newer versions the plain value is compared.
# NOTE(review): presumably older dbplyr returned an `ident` object from
# remote_name() -- confirm.
test_remote_name <- function(x, y) {
  wrap_in_ident <- packageVersion("dbplyr") <= "2.3.4"
  y <- if (wrap_in_ident) ident(y) else y
  expect_equal(dbplyr::remote_name(x), y)
}

# Fixture for the if_all()/if_any() tests. Columns are grouped by prefix so
# the tests can select them with starts_with():
#   b_*  logical columns, n_*  numeric columns, c_*  character columns.
# `ba` and `bb` are constant logical columns (recycled to 4 rows); they do not
# match the "b_" prefix.
scalars_df <- dplyr::tibble(
  row_num = seq(4),
  b_a = c(FALSE, FALSE, TRUE, TRUE),
  b_b = c(FALSE, TRUE, FALSE, TRUE),
  ba = FALSE,
  bb = TRUE,
  n_a = c(2, 3, 6, 7),
  n_b = c(3, 6, 2, 7),
  c_a = c("aa", "ab", "ca", "dd"),
  c_b = c("ab", "bc", "ac", "ad")
)
scalars_sdf <- copy_to(sc, scalars_df, overwrite = TRUE)

# Fixture with list columns (a_*) holding integer vectors, used by the
# array_contains() expectations.
arrays_df <- dplyr::tibble(
  row_num = seq(4),
  a_a = list(1:4, 2:5, 3:6, 4:7),
  a_b = list(4:7, 3:6, 2:5, 1:4)
)
arrays_sdf <- copy_to(sc, arrays_df, overwrite = TRUE)

test_that("'select' works with where(...) predicate", {
  skip_if(!has_predicates)

  # Local column names contain "."; replace it with "_" before comparing
  # against the names collected from the Spark table.
  local_vars <- iris %>%
    select(where(is.numeric)) %>%
    tbl_vars() %>%
    gsub("\\.", "_", .)
  spark_vars <- iris_tbl %>%
    select(where(is.numeric)) %>%
    collect() %>%
    tbl_vars()

  expect_equal(local_vars, spark_vars)
})

test_that("'n_distinct' summarizer works as expected", {
  skip_connection("supports-na")

  # The same summary is applied to the local tibble and to its Spark copy:
  # n_distinct() with the default, na.rm = TRUE and na.rm = FALSE.
  count_distinct_squares <- function(data) {
    summarize(
      data,
      n_distinct_default = n_distinct(x^2),
      n_distinct_na_rm_true = n_distinct(x^2, na.rm = TRUE),
      n_distinct_na_rm_false = n_distinct(x^2, na.rm = FALSE)
    )
  }

  # Input mixes negative/positive integers with NA and NaN values.
  local_df <- dplyr::tibble(x = c(-3L:2L, NA, NaN, NA))
  spark_df <- copy_to(sc, local_df, name = random_string())

  expect_equal(
    count_distinct_squares(local_df),
    collect(count_distinct_squares(spark_df)),
    ignore_attr = TRUE
  )
})

test_that("'summarize' works with where(...) predicate", {
  skip_if(!has_predicates)

  # The lambdas are kept inline inside across() on the Spark side, as in the
  # local calls they are compared against.

  # Mean of every numeric column.
  spark_numeric_means <- iris_tbl %>%
    summarize(across(where(is.numeric), ~ mean(.x, na.rm = TRUE))) %>%
    collect()
  expect_equivalent(
    iris %>% summarize(across(where(is.numeric), mean)),
    spark_numeric_means
  )

  # Mean of the columns selected by name prefix.
  spark_petal_means <- iris_tbl %>%
    summarize(across(starts_with("Petal"), ~ mean(.x, na.rm = TRUE))) %>%
    collect()
  expect_equivalent(
    iris %>% summarize(across(starts_with("Petal"), mean)),
    spark_petal_means
  )

  # Locally the predicate is is.factor, on the Spark table it is is.character.
  # NOTE(review): presumably because Species is stored as a string in Spark.
  spark_distinct_counts <- iris_tbl %>%
    summarize(across(where(is.character), n_distinct)) %>%
    collect()
  expect_equivalent(
    iris %>% summarize(across(where(is.factor), n_distinct)),
    spark_distinct_counts
  )
})

test_that("'mutate' works as expected", {
  # Column names after adding `x`; local names get "." replaced by "_" so
  # they are comparable with the Spark table's names.
  local_vars <- iris %>%
    mutate(x = Species) %>%
    tbl_vars() %>%
    gsub("\\.", "_", .)
  spark_vars <- iris_tbl %>%
    mutate(x = Species) %>%
    collect() %>%
    tbl_vars()

  expect_equal(local_vars, spark_vars)
})


test_that("'mutate' and 'transmute' work with NSE", {
  # The column is chosen at run time: its name is unquoted on the left-hand
  # side and the matching symbol on the right-hand side.
  col_name <- "mpg"
  col_sym <- rlang::sym(col_name)

  spark_mutated <- mtcars_tbl %>%
    mutate(!!col_name := !!col_sym * 2) %>%
    collect()
  expect_equivalent(
    spark_mutated,
    mtcars %>% mutate(!!col_name := !!col_sym * 2)
  )

  spark_transmuted <- mtcars_tbl %>%
    transmute(!!col_name := !!col_sym * 2) %>%
    collect()
  expect_equivalent(
    spark_transmuted,
    mtcars %>% transmute(!!col_name := !!col_sym * 2)
  )
})

test_that("the implementation of 'filter' functions as expected", {
  # Spark side: four consecutive filter() calls, one per measurement column.
  spark_species <- iris_tbl %>%
    filter(Sepal_Length == 5.1) %>%
    filter(Sepal_Width == 3.5) %>%
    filter(Petal_Length == 1.4) %>%
    filter(Petal_Width == 0.2) %>%
    select(Species) %>%
    collect()

  # Local side: rename the columns to the "_" form first so the same filters
  # can be applied, then convert Species to character for the comparison.
  local_species <- iris %>%
    transmute(
      Sepal_Length = `Sepal.Length`,
      Sepal_Width = `Sepal.Width`,
      Petal_Length = `Petal.Length`,
      Petal_Width = `Petal.Width`,
      Species = Species
    ) %>%
    filter(Sepal_Length == 5.1) %>%
    filter(Sepal_Width == 3.5) %>%
    filter(Petal_Length == 1.4) %>%
    filter(Petal_Width == 0.2) %>%
    transmute(Species = as.character(Species))

  expect_equivalent(spark_species, local_species)
})

test_that("if_else works as expected", {
  spark_df <- copy_to(sc, dplyr::tibble(x = c(0.9, NA_real_, 1.1)))

  # NOTE(review): the test name mentions if_else, but both calls below use
  # ifelse(), the second one with a fourth argument -- confirm this is the
  # intended function.

  # Three-argument form: the NA input yields NA.
  labels_default <- spark_df %>%
    dplyr::mutate(x = ifelse(x > 1, "good", "bad")) %>%
    dplyr::pull(x)
  expect_equal(labels_default, c("bad", NA, "good"))

  # Four-argument form: the NA input yields the extra value.
  labels_with_missing <- spark_df %>%
    dplyr::mutate(x = ifelse(x > 1, "good", "bad", "unknown")) %>%
    dplyr::pull(x)
  expect_equal(labels_with_missing, c("bad", "unknown", "good"))
})

test_that("if_all and if_any work as expected", {
  test_requires_package_version("dbplyr", 2)

  # Rows where at least one of the b_* logical columns is TRUE.
  spark_any <- scalars_sdf %>%
    filter(if_any(starts_with("b_"))) %>%
    collect()
  expect_equivalent(
    spark_any,
    scalars_df %>% filter(if_any(starts_with("b_")))
  )

  # Rows where every b_* logical column is TRUE.
  spark_all <- scalars_sdf %>%
    filter(if_all(starts_with("b_"))) %>%
    collect()
  expect_equivalent(
    spark_all,
    scalars_df %>% filter(if_all(starts_with("b_")))
  )
})

test_that("if_all and if_any work as expected with boolean predicates", {
  test_requires_package_version("dbplyr", 2)
  test_requires_version("2.4.0")
  skip_on_arrow()

  # --- single lambda predicate over the n_* columns -----------------------
  spark_all_gt_5 <- scalars_sdf %>%
    filter(if_all(starts_with("n_"), ~ .x > 5)) %>%
    collect()
  expect_equivalent(
    spark_all_gt_5,
    scalars_df %>% filter(if_all(starts_with("n_"), ~ .x > 5))
  )

  spark_any_gt_5 <- scalars_sdf %>%
    filter(if_any(starts_with("n_"), ~ .x > 5)) %>%
    collect()
  expect_equivalent(
    spark_any_gt_5,
    scalars_df %>% filter(if_any(starts_with("n_"), ~ .x > 5))
  )

  # --- several lambda predicates over the n_* columns ---------------------
  spark_all_multi <- scalars_sdf %>%
    filter(if_all(starts_with("n_"), c(~ .x > 5, ~ .x < 3))) %>%
    collect()
  expect_equivalent(
    spark_all_multi,
    scalars_df %>% filter(if_all(starts_with("n_"), c(~ .x > 5, ~ .x < 3)))
  )

  spark_any_multi <- scalars_sdf %>%
    filter(if_any(starts_with("n_"), c(~ .x > 6, ~ .x < 3))) %>%
    collect()
  expect_equivalent(
    spark_any_multi,
    scalars_df %>% filter(if_any(starts_with("n_"), c(~ .x > 6, ~ .x < 3)))
  )

  # --- function plus extra argument over the c_* columns ------------------
  # if_all/if_any is totally dependent on dbplyr implementation
  # there is a warning that does not seem to be
  # generated by sparklyr code
  expect_warning(
    scalars_sdf %>%
      dplyr::filter(if_all(starts_with("c_"), grepl, "caabac"))
  )

  rows_all_caabac <- scalars_sdf %>%
    dplyr::filter(if_all(starts_with("c_"), grepl, "caabac")) %>%
    pull(row_num)
  expect_equivalent(rows_all_caabac, c(1L, 3L))

  rows_any_aac <- scalars_sdf %>%
    dplyr::filter(if_any(starts_with("c_"), grepl, "aac")) %>%
    pull(row_num)
  expect_equivalent(rows_any_aac, c(1L, 3L))

  rows_any_bcad <- scalars_sdf %>%
    dplyr::filter(if_any(starts_with("c_"), grepl, "bcad")) %>%
    pull(row_num)
  expect_equivalent(rows_any_bcad, c(2L, 3L, 4L))

  # --- array_contains() over the a_* list columns -------------------------
  rows_all_contain_5 <- arrays_sdf %>%
    filter(if_all(starts_with("a_"), ~ array_contains(.x, 5L))) %>%
    pull(row_num)
  expect_equivalent(rows_all_contain_5, c(2L, 3L))

  rows_any_contain_7 <- arrays_sdf %>%
    filter(if_any(starts_with("a_"), ~ array_contains(.x, 7L))) %>%
    pull(row_num)
  expect_equivalent(rows_any_contain_7, c(1L, 4L))
})

test_that("grepl works as expected", {
  regexes <- c(
    "a|c", ".", "b", "x|z", "", "y", "e", "^", "$", "^$", "[0-9]", "[a-z]", "[b-z]"
  )

  # Compare two data frames column by column after coercing each column to
  # character (handles an edge case for arrow-enabled Spark connections).
  #
  # actual:   result collected from Spark.
  # expected: result computed locally; its column names drive the comparison.
  # info:     extra text attached to a failure so the offending regex and
  #           column can be identified.
  verify_equivalent <- function(actual, expected, info) {
    for (col in colnames(expected)) {
      expect_equivalent(
        as.character(actual[[col]]),
        as.character(expected[[col]]),
        info = info
      )
    }
  }

  for (regex in regexes) {
    # The Spark result is passed as `actual` and the local one as `expected`,
    # so failure reports are labelled the right way round.
    verify_equivalent(
      df2_tbl %>% dplyr::filter(grepl(regex, b)) %>% collect(),
      df2 %>% dplyr::filter(grepl(regex, b)),
      info = sprintf("regex '%s' applied to column b", regex)
    )
    verify_equivalent(
      df2_tbl %>% dplyr::filter(grepl(regex, c)) %>% collect(),
      df2 %>% dplyr::filter(grepl(regex, c)),
      info = sprintf("regex '%s' applied to column c", regex)
    )
  }
})

test_that("'head' uses 'limit' clause", {
  test_requires("dbplyr")

  # Render the SQL for head() and look for a LIMIT keyword in it.
  rendered_sql <- sql_render(head(iris_tbl))
  has_limit <- grepl("LIMIT", rendered_sql)

  expect_true(has_limit)
})

test_that("'sdf_broadcast' forces broadcast hash join", {
  skip_connection("sdf-broadcast")

  # Textual form of the analyzed logical plan for a join whose left side was
  # marked with sdf_broadcast().
  query_plan <- df1_tbl %>%
    sdf_broadcast() %>%
    left_join(df2_tbl, by = "b") %>%
    spark_dataframe() %>%
    invoke("queryExecution") %>%
    invoke("analyzed") %>%
    invoke("toString")

  # Accept "Broadcast" or "broadcast". The previous pattern "B|broadcast"
  # parsed as "B" OR "broadcast" and therefore matched any plan containing a
  # capital B, whether or not a broadcast hint was present.
  expect_match(query_plan, "[Bb]roadcast")
})

test_that("compute() works as expected", {
  base_sdf <- sdf_10
  evens <- base_sdf %>% dplyr::filter(id %% 2 == 0)
  odds <- base_sdf %>% dplyr::filter(id %% 2 == 1)

  # Queries that have not been computed have no remote name.
  expect_null(dbplyr::remote_name(evens))
  expect_null(dbplyr::remote_name(odds))

  # caching Spark dataframes with random names
  evens_cached <- dplyr::compute(evens)
  odds_cached <- dplyr::compute(odds)

  expect_equivalent(
    collect(evens_cached),
    dplyr::tibble(id = c(2L, 4L, 6L, 8L, 10L))
  )
  expect_equivalent(
    collect(odds_cached),
    dplyr::tibble(id = c(1L, 3L, 5L, 7L, 9L))
  )

  # caching Spark dataframes with pre-determined names
  mod3_is_1 <- base_sdf %>% dplyr::filter(id %% 3 == 1)
  mod3_is_2 <- base_sdf %>% dplyr::filter(id %% 3 == 2)

  expect_null(dbplyr::remote_name(mod3_is_1))
  expect_null(dbplyr::remote_name(mod3_is_2))

  mod3_is_1_cached <- dplyr::compute(mod3_is_1, name = "congruent_to_1_mod_3")
  mod3_is_2_cached <- dplyr::compute(mod3_is_2, name = "congruent_to_2_mod_3")

  test_remote_name(mod3_is_1_cached, "congruent_to_1_mod_3")
  test_remote_name(mod3_is_2_cached, "congruent_to_2_mod_3")

  # Same check with the name passed positionally instead of via `name =`.
  temp_view <- dplyr::compute(mod3_is_2, "temp_view")
  test_remote_name(temp_view, "temp_view")

  expect_equivalent(
    collect(mod3_is_1_cached),
    dplyr::tibble(id = c(1L, 4L, 7L, 10L))
  )
  expect_equivalent(
    collect(mod3_is_2_cached),
    dplyr::tibble(id = c(2L, 5L, 8L))
  )
})

test_that("mutate creates NA_real_ column correctly", {
  # Add an all-NA double column next to a computed column and collect.
  spark_result <- sdf_5 %>%
    dplyr::mutate(z = NA_real_, sq = id * id) %>%
    collect()
  local_expected <- dplyr::tibble(id = seq(5), z = NA_real_, sq = id * id)

  expect_equivalent(spark_result, local_expected)
})

test_that("transmute creates NA_real_ column correctly", {
  # transmute() keeps only the new columns, so `id` is absent from the result.
  spark_result <- sdf_5 %>%
    dplyr::transmute(z = NA_real_, sq = id * id) %>%
    collect()
  local_expected <- dplyr::tibble(z = NA_real_, sq = seq(5) * seq(5))

  expect_equivalent(spark_result, local_expected)
})

test_that("overwriting a temp view", {
  # Skipping while researching why override works on non-connect methods
  skip()

  view_name <- random_string()

  # First compute() registers a two-column result under `view_name`; the
  # second compute() reuses the same name for the plain sdf_5 table.
  first_version <- sdf_5 %>%
    dplyr::mutate(foo = "foo") %>%
    dplyr::compute(name = view_name)
  second_version <- dplyr::compute(sdf_5, name = view_name)

  # Both the returned table and a fresh lookup by name must show only `id`.
  id_only <- dplyr::tibble(id = seq(5))
  expect_equivalent(collect(second_version), id_only)
  expect_equivalent(
    dplyr::tbl(sc, view_name) %>% collect(),
    dplyr::tibble(id = seq(5))
  )
})

test_that("dplyr::distinct() impl is configurable", {
  # Switch the distinct() implementation for this test only. options()
  # returns the previous values, which are restored on exit; previously the
  # option was unconditionally reset to NULL, discarding any value set
  # before the test ran.
  old_opts <- options(sparklyr.dplyr_distinct.impl = "tbl_lazy")
  on.exit(options(old_opts), add = TRUE)

  tbl_name <- random_string()
  sdf <- copy_to(sc, data.frame(a = c(1, 1)), name = tbl_name)

  # Split the generated SQL into whitespace-separated tokens.
  query <- sdf %>%
    dplyr::distinct() %>%
    dbplyr::remote_query() %>%
    strsplit("\\s+")

  # If the third token starts with the back-quoted table name, collapse it to
  # "*" so both forms of the select list compare equal below.
  # NOTE(review): presumably some dbplyr versions emit `<table>`.* here.
  query[[1]][[3]] <- gsub(sprintf("`%s`.*", tbl_name), "*", query[[1]][[3]])

  expect_equal(
    toupper(query[[1]]),
    c("SELECT", "DISTINCT", "*", "FROM", sprintf("`%s`", toupper(tbl_name)))
  )
  expect_equivalent(
    sdf %>% dplyr::distinct() %>% collect(),
    data.frame(a = 1)
  )
})

test_that("process_tbl_name works as expected", {
  skip_if(any(grepl("connect_", class(sc))))

  # Plain names pass through; dotted names become schema- or
  # catalog-qualified identifiers.
  expect_equal(sparklyr:::process_tbl_name("a"), "a")
  expect_equal(sparklyr:::process_tbl_name("xyz"), "xyz")
  expect_equal(sparklyr:::process_tbl_name("x.y"), dbplyr::in_schema("x", "y"))
  expect_equal(sparklyr:::process_tbl_name("x.y.z"), dbplyr::in_catalog("x", "y", "z"))

  # Use table names private to this test. The previous names "df1"/"df2"
  # overwrote the Spark tables behind the file-level df1_tbl/df2_tbl fixtures
  # with a different schema, which later tests still read.
  ptn_df1 <- dplyr::tibble(a = 1, g = 2) %>%
    copy_to(sc, ., "ptn_df1", overwrite = TRUE)
  ptn_df2 <- dplyr::tibble(b = 1, g = 2) %>%
    copy_to(sc, ., "ptn_df2", overwrite = TRUE)

  # A raw SQL query containing dotted column references must still be usable
  # as a table source.
  query <- sql(
    "SELECT ptn_df1.a, ptn_df2.b, ptn_df1.g FROM ptn_df1 LEFT JOIN ptn_df2 ON ptn_df1.g = ptn_df2.g"
  )
  expect_equivalent(
    tbl(sc, query) %>% collect(),
    dplyr::tibble(a = 1, b = 1, g = 2)
  )
})

test_that("in_schema() works as expected", {
  skip_on_arrow()
  skip_on_livy()

  # The body only runs against Spark versions below 3.4.0.
  if (spark_version(sc) < "3.4.0") {
    db_name <- random_string("test_db_")

    # Create a fresh database containing an empty Hive table with one INT
    # column.
    create_db_sql <- sprintf("CREATE DATABASE `%s`", db_name)
    create_tbl_sql <- sprintf(
      "CREATE TABLE IF NOT EXISTS `%s`.`hive_tbl` (`x` INT) USING hive",
      db_name
    )
    DBI::dbGetQuery(sc, create_db_sql)
    DBI::dbGetQuery(sc, create_tbl_sql)

    # Reading it through in_schema() yields a zero-row integer column.
    hive_contents <- dplyr::tbl(sc, dbplyr::in_schema(db_name, "hive_tbl")) %>%
      collect()
    expect_equivalent(hive_contents, dplyr::tibble(x = integer()))
  }
})

test_that("sdf_remote_name returns null for computed tables", {
  # The table itself reports its name.
  test_remote_name(iris_tbl, "iris")

  # A filtered query on top of it has no remote name.
  virginica_only <- iris_tbl %>% filter(Species == "virginica")
  expect_null(dbplyr::remote_name(virginica_only))
})

test_that("sdf_remote_name ignores the last group_by() operation(s)", {
  # Stack up to four group_by() calls, checking the name after each one.
  grouped <- iris_tbl
  for (n_applied in seq_len(4)) {
    grouped <- dplyr::group_by(grouped, Species)
    test_remote_name(grouped, "iris")
  }
})

test_that("sdf_remote_name ignores the last ungroup() operation(s)", {
  # Stack up to four ungroup() calls, checking the name after each one.
  ungrouped <- iris_tbl
  for (n_applied in seq_len(4)) {
    ungrouped <- dplyr::ungroup(ungrouped)
    test_remote_name(ungrouped, "iris")
  }
})

test_that("sdf_remote_name works with arrange followed by compute", {
  # Letters are supplied in descending order (z, y, x) with nums 1..3.
  letters_sdf <- copy_to(
    sc,
    dplyr::tibble(lts = letters[26:24], nums = seq(3))
  )
  ordered_tbl <- letters_sdf %>%
    arrange(lts) %>%
    compute(name = "ordered_tbl")

  test_remote_name(ordered_tbl, "ordered_tbl")

  # Reading the computed table back by name returns the rows sorted by `lts`.
  expect_equivalent(
    tbl(sc, "ordered_tbl") %>% collect(),
    dplyr::tibble(lts = letters[24:26], nums = 3:1)
  )
})

test_that("result from dplyr::compute() has remote name", {
  # compute() without an explicit name must still yield a named table.
  computed <- iris_tbl %>%
    dplyr::mutate(y = 5) %>%
    dplyr::compute()
  computed_name <- dbplyr::remote_name(computed)

  expect_false(is.null(computed_name))
})

test_that("tbl_ptype.tbl_spark works as expected", {
  skip_if(!has_predicates)

  # df1 has an integer column `a` and a character column `b`; select_if()
  # must pick columns by the type predicate.
  expect_equal(colnames(dplyr::select_if(df1_tbl, is.integer)), "a")
  expect_equal(colnames(dplyr::select_if(df1_tbl, is.numeric)), "a")
  expect_equal(colnames(dplyr::select_if(df1_tbl, is.character)), "b")
  expect_equal(colnames(dplyr::select_if(df1_tbl, is.list)), character())
})

test_that("summarise(.groups=)", {
  grouped_sdf <- copy_to(sc, data.frame(x = 1, y = 2)) %>%
    group_by(x, y)

  # Grouping variables left after summarise() for each `.groups` setting;
  # omitting `.groups` gives the same result as "drop_last".
  expect_equal(grouped_sdf %>% summarise() %>% group_vars(), "x")
  expect_equal(
    grouped_sdf %>% summarise(.groups = "drop_last") %>% group_vars(),
    "x"
  )
  expect_equal(
    grouped_sdf %>% summarise(.groups = "drop") %>% group_vars(),
    character()
  )
  expect_equal(
    grouped_sdf %>% summarise(.groups = "keep") %>% group_vars(),
    c("x", "y")
  )

  # The summarised values must match the local result for every setting.
  local_df <- dplyr::tibble(val1 = c(1, 2, 1, 2), val2 = c(10, 20, 30, 40))
  spark_df <- copy_to(sc, local_df, name = random_string())
  for (groups_arg in c("drop_last", "drop", "keep")) {
    spark_summary <- spark_df %>%
      group_by(val1) %>%
      summarize(result = sum(val2, na.rm = TRUE), .groups = groups_arg) %>%
      arrange(val1) %>%
      collect()
    local_summary <- local_df %>%
      group_by(val1) %>%
      summarize(result = sum(val2, na.rm = TRUE), .groups = groups_arg) %>%
      arrange(val1)

    expect_equivalent(spark_summary, local_summary)
  }
})

test_that("tbl_spark prints", {
  # Only the first printed line (the source header) is checked.
  header_line <- capture.output(print(iris_tbl))[1]

  expect_equal(header_line, "# Source:   table<`iris`> [?? x 5]")
})


test_that("pmin and pmax work", {
  pmin_df <- data.frame(x = 11:20, y = 1:10)
  tbl_pmin_df <- sdf_copy_to(sc, pmin_df, overwrite = TRUE)

  # Row-wise minimum and maximum computed in Spark ...
  spark_result <- tbl_pmin_df %>%
    mutate(p_min = pmin(x, y), p_max = pmax(x, y)) %>%
    collect()

  # ... and the same computed locally.
  local_result <- mutate(pmin_df, p_min = pmin(x, y), p_max = pmax(x, y))

  # Element-wise comparison of every cell.
  all_cells_match <- all(spark_result == local_result)
  expect_true(all_cells_match)

  # na.rm = FALSE is rejected for both functions, with an error message that
  # mentions "na.rm = TRUE".
  expect_error(
    tbl_pmin_df %>% mutate(x = pmin(x, y, na.rm = FALSE)) %>% collect(),
    regexp = "na.rm = TRUE"
  )
  expect_error(
    tbl_pmin_df %>% mutate(x = pmax(x, y, na.rm = FALSE)) %>% collect(),
    regexp = "na.rm = TRUE"
  )
})

# End-of-file cleanup; test_clear_cache() is a helper defined outside this
# file.
# NOTE(review): presumably it drops the tables cached by the tests above --
# confirm against the helper's definition.
test_clear_cache()
